import argparse
import os
from itertools import product

import matplotlib.pyplot as plt
import numpy as np

# Command-line interface: which t/f grid and which LSH variant to compare,
# and the directory holding the saved tv_distances_*.npy arrays.
parser = argparse.ArgumentParser(
    description='Compare total variation distance histograms across different t and f values'
)
parser.add_argument(
    '--t_values', type=int, nargs='+', default=[4, 8, 16, 32, 128],
    help='List of t values to compare (default: [4, 8, 16, 32, 128])',
)
parser.add_argument(
    '--f_values', type=int, nargs='+', default=[2, 4, 8, 16],
    help='List of f values to compare (default: [2, 4, 8, 16])',
)
parser.add_argument(
    '--method', type=str,
    choices=['lsh', 'lsh_dual', 'lsh_whitening', 'lsh_whitening_dual'],
    default='lsh_dual', help='Which method to compare (default: lsh_dual)',
)
# np.load does not expand "~", so expand it here. argparse also applies
# `type` to string defaults, so the default path gets expanded as well.
parser.add_argument(
    '--base_path', type=os.path.expanduser,
    default='~/results/sae-softmax/gemma-2-2b/',
    help='Base path for data files ("~" is expanded)',
)

args = parser.parse_args()

# Overlay one histogram per (t, f) combination on a single figure and save it
# as tv_distance_comparison_<method>.png in the current working directory.
plt.figure(figsize=(15, 10))

# One viridis colour per combination in the full t x f grid; colours are
# handed out in order to the combinations whose file was actually found.
colors = plt.cm.viridis(np.linspace(0, 1, len(args.t_values) * len(args.f_values)))

# Only the first MAX_HIST_SAMPLES values of each file are histogrammed.
MAX_HIST_SAMPLES = 1000
# Expand "~" and join with os.path.join so base_path works with or without
# a trailing separator (np.load does not expand "~" itself).
data_dir = os.path.expanduser(args.base_path)

color_idx = 0
for t, f in product(args.t_values, args.f_values):
    # Every method uses the same naming scheme:
    # tv_distances_<method>_t<t>_f<f>.npy
    filename = f'tv_distances_{args.method}_t{t}_f{f}.npy'
    try:
        data = np.load(os.path.join(data_dir, filename))[:MAX_HIST_SAMPLES]
    except FileNotFoundError:
        print(f"Warning: Could not find file for t={t}, f={f}")
        continue
    plt.hist(data, bins=50, alpha=0.5,
             label=f't={t}, f={f}',
             color=colors[color_idx])
    color_idx += 1

# Add labels and title
plt.xlabel('Total Variation Distance')
plt.ylabel('Frequency')
plt.title(f'Comparison of Total Variation Distances for {args.method} with different t/f values')
# Skip the legend when nothing was plotted; matplotlib would otherwise warn
# that there are no labelled artists.
if color_idx:
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
else:
    print("Warning: no data files found; the saved plot will be empty")

# Adjust layout to prevent label cutoff
plt.tight_layout()

# Save the plot
plt.savefig(f'tv_distance_comparison_{args.method}.png', dpi=300, bbox_inches='tight')
plt.close()

# Print mean/std/min/max of each (t, f) file's full saved array
# (no truncation here, unlike a sample-limited histogram).
print("\nSummary Statistics:")
print("-" * 50)
# Expand "~" and join with os.path.join so base_path works with or without
# a trailing separator (np.load does not expand "~" itself).
data_dir = os.path.expanduser(args.base_path)
for t, f in product(args.t_values, args.f_values):
    filename = f'tv_distances_{args.method}_t{t}_f{f}.npy'
    try:
        data = np.load(os.path.join(data_dir, filename))
    except FileNotFoundError:
        print(f"\nt={t}, f={f}: File not found")
        continue
    # np.min / np.max raise ValueError on an empty array; report it instead.
    if data.size == 0:
        print(f"\nt={t}, f={f}: File is empty")
        continue
    print(f"\nt={t}, f={f}:")
    print(f"Mean: {np.mean(data):.4f}")
    print(f"Std:  {np.std(data):.4f}")
    print(f"Min:  {np.min(data):.4f}")
    print(f"Max:  {np.max(data):.4f}")